import numpy as np
from tqdm import tqdm
import concurrent.futures
from functools import partial
import pickle
import os
import matplotlib
matplotlib.use('Agg')  # Use non-interactive backend to prevent plot windows from appearing
from faster_caching import *
from plot_utils import *

def run_single_policy(policy_func, a_list, C, xi=None, Q=None, predictor=None, forced=None, threshold=None):
    """Call ``policy_func`` with the argument layout its policy expects.

    The policy is identified by ``policy_func.__name__``:

    * ``tail_optimized_LRU_cache_policy(a_list, C, xi, Q, predictor=..., forced=...)``
    * ``LRU_cache_policy(a_list, C, forced)``
    * ``thre_lru_cache_policy(a_list, C, threshold, forced)``

    Returns:
        Whatever the selected policy function returns.

    Raises:
        ValueError: if ``policy_func`` has a name other than the three above.
    """
    name = policy_func.__name__
    # Each entry defers the call so only the matching policy actually runs.
    callers = {
        'tail_optimized_LRU_cache_policy':
            lambda: policy_func(a_list, C, xi, Q, predictor=predictor, forced=forced),
        'LRU_cache_policy':
            lambda: policy_func(a_list, C, forced),
        'thre_lru_cache_policy':
            lambda: policy_func(a_list, C, threshold, forced),
    }
    call = callers.get(name)
    if call is None:
        raise ValueError(f"Unknown policy function: {name}")
    return call()

def process_cache_capacity(C, a_list, xi, Q, forced, percentiles):
    """Run every configured policy at cache capacity ``C`` and summarise it.

    The three policies are submitted to a ``ThreadPoolExecutor`` and executed
    through ``run_single_policy``. A policy that raises is reported on stdout
    and recorded as ``np.nan`` instead of aborting the whole capacity.
    NOTE(review): unless the policy code releases the GIL, these threads give
    concurrency rather than true CPU parallelism.

    Returns:
        Flat dict with one ``f'{name}_{p}'`` entry per policy name
        (``vanilla_lru``, ``lru``, ``thre_lru``) and percentile ``p``, holding
        ``np.percentile`` of that policy's result, plus ``'capacity': C``.
    """
    # (result name, policy function, predictor, threshold)
    # NOTE(review): 'lru' passes the *string* 'None' as predictor, not the
    # None object -- presumably a predictor name understood by
    # tail_optimized_LRU_cache_policy; confirm before changing.
    configs = (
        ('vanilla_lru', LRU_cache_policy, None, None),
        ('lru', tail_optimized_LRU_cache_policy, 'None', None),
        ('thre_lru', thre_lru_cache_policy, None, 1024),
    )

    uncached_by_policy = {}
    with concurrent.futures.ThreadPoolExecutor() as pool:
        pending = {
            pool.submit(
                run_single_policy,
                policy_func=func,
                a_list=a_list,
                C=C,
                xi=xi,
                Q=Q,
                predictor=predictor,
                forced=forced,
                threshold=threshold,
            ): name
            for name, func, predictor, threshold in configs
        }

        for done in concurrent.futures.as_completed(pending):
            name = pending[done]
            try:
                uncached_by_policy[name] = done.result()
            except Exception as exc:
                print(f"Policy {name} generated an exception: {exc}")
                uncached_by_policy[name] = None

    summary = {}
    for p in percentiles:
        for name, uncached in uncached_by_policy.items():
            # Failed policies yield NaN so downstream plots/CSVs keep their shape.
            summary[f'{name}_{p}'] = np.nan if uncached is None else np.percentile(uncached, p)

    # Echo the capacity back so the caller can identify this result.
    summary['capacity'] = C
    return summary

def run_parallel_cache_evaluation(C_values, a_list, xi, Q, forced, percentiles, max_workers=None):
    """Evaluate every cache policy at every capacity in ``C_values``.

    Two levels of parallelism are used: capacities are distributed over a
    ``ProcessPoolExecutor`` (one task per capacity), and each task runs
    ``process_cache_capacity``, which evaluates the individual policies on a
    thread pool.

    Args:
        C_values: Sequence of cache capacities to evaluate.
        a_list: Request trace handed unchanged to ``process_cache_capacity``.
        xi: Forwarded to ``process_cache_capacity``.
        Q: Forwarded to ``process_cache_capacity``.
        forced: Forwarded to ``process_cache_capacity``.
        percentiles: Percentiles of uncached tokens to collect.
        max_workers: Size of the process pool; ``None`` uses the executor's
            default.

    Returns:
        Tuple ``(lru, vanilla_lru, thre_lru)``. Each element is a dict mapping
        every percentile ``p`` to a list whose i-th entry belongs to
        ``C_values[i]``. A capacity whose task raised contributes ``np.nan``,
        so every list always has exactly ``len(C_values)`` entries.
    """
    capacities = list(C_values)
    policy_names = ['lru', 'vanilla_lru', 'thre_lru']

    # Bind every argument except the capacity.
    process_fn = partial(
        process_cache_capacity,
        a_list=a_list,
        xi=xi,
        Q=Q,
        forced=forced,
        percentiles=percentiles
    )

    # Results are stored by the index of their capacity in C_values rather
    # than in completion order. This keeps the output aligned with C_values
    # even when some tasks fail or C_values contains duplicate capacities.
    per_capacity = [None] * len(capacities)

    with tqdm(total=len(capacities), desc="Testing cache capacities") as progress_bar:
        with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
            future_to_index = {
                executor.submit(process_fn, C): i for i, C in enumerate(capacities)
            }
            for future in concurrent.futures.as_completed(future_to_index):
                i = future_to_index[future]
                try:
                    per_capacity[i] = future.result()
                except Exception as exc:
                    # Best effort: report the failure and leave the slot as
                    # None; it becomes NaN below so shapes stay consistent.
                    print(f"Capacity {capacities[i]} processing generated an exception: {exc}")
                progress_bar.update(1)

    def _metric(result, key):
        """Return ``result[key]``, or NaN if the capacity failed or lacks the key."""
        if result is None:
            return np.nan
        return result.get(key, np.nan)

    return tuple(
        {p: [_metric(result, f"{policy}_{p}") for result in per_capacity] for p in percentiles}
        for policy in policy_names
    )

def create_simplified_plots(C_values, lru_results, vanilla_lru_results, thre_lru_results, 
                          percentiles, title, filename):
    """Plot uncached-token percentiles versus cache capacity for three policies.

    One subplot is drawn per entry of ``percentiles``. Each subplot overlays
    T-LRU, vanilla LRU and threshold LRU on a log-scaled y axis. The figure
    gets ``title`` as its overall title, is written to ``filename`` at 300 dpi,
    and is then closed.

    Args:
        C_values: Cache capacities used as the x coordinates.
        lru_results: Mapping percentile -> values aligned with ``C_values``
            (T-LRU).
        vanilla_lru_results: Same layout, for vanilla LRU.
        thre_lru_results: Same layout, for threshold LRU.
        percentiles: Percentiles to plot, one subplot each.
        title: Figure-level title drawn above all subplots.
        filename: Output image path.

    Returns:
        ``(fig, axes)``. ``axes`` is a list when there is a single percentile
        and a flattened array otherwise. ``fig`` has already been closed.
    """
    colors = {
        'lru': '#7BADED',         # T-LRU None
        'vanilla_lru': '#8c564b', # vanilla LRU
        'thre_lru': '#FFD700'     # threshold LRU
    }
    markers = {
        'lru': 's',          # square for T-LRU None
        'vanilla_lru': 'v',  # triangle down for vanilla LRU
        'thre_lru': 'x'      # x for threshold LRU
    }
    policy_names = {
        'lru': 'T-LRU None',
        'vanilla_lru': 'Vanilla LRU',
        'thre_lru': 'Threshold LRU'
    }
    series = [
        ('lru', lru_results),
        ('vanilla_lru', vanilla_lru_results),
        ('thre_lru', thre_lru_results),
    ]

    n_plots = len(percentiles)
    fig, axes = plt.subplots(1, n_plots, figsize=(18, 2.8))
    # plt.subplots returns a bare Axes for a single column; normalise to a sequence.
    axes = [axes] if n_plots == 1 else axes.flatten()

    for i, p in enumerate(percentiles):
        ax = axes[i]
        for policy, results in series:
            ax.plot(C_values, results[p],
                    marker=markers[policy],
                    color=colors[policy],
                    label=policy_names[policy],
                    linewidth=2)

        ax.set_xlabel('Cache Capacity (tokens)')
        ax.set_ylabel(f'p{p} Uncached Tokens')
        ax.set_title(f'p{p} Percentile')
        ax.grid(True, alpha=0.3)

        # All subplots show the same series, so one legend suffices.
        if i == 0:
            ax.legend()

        # Log scale: non-positive values cannot be drawn on this axis.
        ax.set_yscale('log')

    # Draw the caller-supplied title above all subplots, then lay out so it
    # does not overlap the axes.
    fig.suptitle(title)
    fig.tight_layout()
    fig.savefig(filename, dpi=300)
    plt.close(fig)  # release the figure; nothing is shown interactively

    return fig, axes

def save_results_to_csv(C_values, lru_results, vanilla_lru_results, thre_lru_results, percentiles, save_dir):
    """Write the per-capacity percentile tables for all policies to ``save_dir``.

    Two files are produced:

    * ``uncached_tokens.csv`` -- the raw percentile values.
    * ``latency_seconds.csv`` -- each value passed through
      ``uncached_tokens_to_latency``.

    Both have a ``cache_capacity`` column followed, for each percentile ``p``
    in order, by ``T-LRU_p{p}``, ``Vanilla_LRU_p{p}`` and
    ``Threshold_LRU_p{p}``.
    """
    labelled = (
        ('T-LRU', lru_results),
        ('Vanilla_LRU', vanilla_lru_results),
        ('Threshold_LRU', thre_lru_results),
    )
    uncached_path = f"{save_dir}/uncached_tokens.csv"
    latency_path = f"{save_dir}/latency_seconds.csv"

    uncached_df = pd.DataFrame({'cache_capacity': C_values})
    latency_df = pd.DataFrame({'cache_capacity': C_values})

    for p in percentiles:
        for label, series in labelled:
            column = f'{label}_p{p}'
            values = series[p]
            uncached_df[column] = values
            latency_df[column] = [uncached_tokens_to_latency(tokens) for tokens in values]

    uncached_df.to_csv(uncached_path, index=False)
    latency_df.to_csv(latency_path, index=False)

    print(f"Results saved as CSV to {uncached_path} and {latency_path}")

# Main execution
if __name__ == "__main__":
    # ---- Sweep configuration ----
    C_values = [1000, 2000, 4000, 6000, 8000, 10000]  # cache capacities to sweep

    # Alternative xi sweeps kept for reference (trailing notes are the
    # approximate latencies recorded when each sweep was chosen):
    # xi_values = [400, 600, 800, 1000, 1200]
    # xi_values = [211, 533, 855, 1177, 1498]  # ~[0.03, 0.04, 0.06, 0.08, 0.1] latency
    # xi_values = [2302, 3107, 7934, 15980, 32070]  # 0.1, 0.2, 0.5, 1, 2
    # xi_values = [15980, 32070, 80343, 160798, 321707]  # same labels as the line above despite different values -- TODO confirm
    # NOTE(review): the active sweep was annotated "0.05, 0.1, 0.2, 0.5" --
    # four labels for five values; confirm the intended mapping.
    xi_values = [694, 1498, 2302, 3107, 7934]
    Q = 100                         # held fixed across the sweep
    forced = 0                      # held fixed across the sweep
    percentiles = [50, 90, 95, 99]  # uncached-token percentiles to report

    # ---- Workload: ShareGPT trace restricted to conv_idx 1..max_conv_idx ----
    a_list = pd.DataFrame(load_data("ShareGPT"))
    max_conv_idx = 200
    a_list = a_list[a_list['conv_idx'].isin(range(1, max_conv_idx + 1))]

    # ---- One full capacity sweep per xi ----
    for xi in xi_values:
        print(f"Running for xi={xi}, Q={Q}")
        set_name_run(f"xi{xi}Q{Q}forced{forced}_maxconvidx{max_conv_idx}")

        lru_results, vanilla_lru_results, thre_lru_results = run_parallel_cache_evaluation(
            C_values, a_list, xi, Q, forced, percentiles,
            max_workers=None,  # None: ProcessPoolExecutor picks its default worker count
        )

        # The plot is written to the current directory; data goes to save_dir.
        filename = f"simplified_comparison_xi{xi}Q{Q}_maxconvidx{max_conv_idx}.png"
        create_simplified_plots(
            C_values,
            lru_results,
            vanilla_lru_results,
            thre_lru_results,
            percentiles,
            f"Cache Policy Comparison (xi={xi}, Q={Q})",
            filename,
        )

        save_dir = f"./results/xi{xi}Q{Q}forced{forced}_maxconvidx{max_conv_idx}"
        os.makedirs(save_dir, exist_ok=True)

        pickle_path = f"{save_dir}/simplified_caching.pkl"
        with open(pickle_path, "wb") as fh:
            pickle.dump(
                {
                    'tail_lru_none': lru_results,
                    'vanilla_lru': vanilla_lru_results,
                    'thre_lru': thre_lru_results,
                },
                fh,
            )
        print(f"Results saved to {pickle_path}")

        save_results_to_csv(
            C_values, lru_results, vanilla_lru_results, thre_lru_results, percentiles, save_dir
        )

        print(f"Plot saved to {filename}") 